library(quanteda)
library(conText)
library(dplyr)
library(text2vec)
source("utils.R")

# Fix the RNG state so the resampling done further down is reproducible.
set.seed(123L)

# Identifiers of the corpora to process; each one is spliced into the
# input and output file paths used by the functions below.
countries <- c(
  "djazairess",
  "maghress",
  "masress",
  "sauress",
  "turess"
)

# Arabic anchor terms passed to the similarity scoring as first_vec and
# second_vec: first_ar is "the opposition", second_ar is "the support".
first_ar <- "المعارضة"
second_ar <- "الدعم"

# Bootstrap cosine-similarity scores for a single country.
#
# Reads the country's pre-generated tokens object, resamples its documents
# with replacement B times, scores every resample with
# get_similarity_scores2(), and summarises the resulting distribution of
# `val` within each `group`.
#
# Args:
#   country_name:    Identifier used to build the input path
#                    "data/analysis_toks/<country_name>_target_toks_leader.rds".
#   local_glove:     Forwarded to get_similarity_scores2() as `pre_trained`.
#   local_transform: Forwarded to get_similarity_scores2() as
#                    `transform_matrix`.
#   B:               Number of bootstrap iterations (default 10).
#
# Returns:
#   A list with two elements:
#     bootstrap_distribution: the per-iteration results row-bound into one
#                             data frame, tagged with `bootstrap_iter`.
#     stats_summary:          per-`group` mean, median, se (standard
#                             deviation across iterations) and the 2.5% /
#                             97.5% quantiles of `val`.
#
# Reads the file-level globals `first_ar` and `second_ar`.
# NOTE(review): get_similarity_scores2() is not defined in this file
# (presumably it comes from utils.R -- confirm). The summary step requires
# its result to carry `group` and `val` columns, and the tokens object is
# assumed to carry a "yearwk" docvar -- confirm both.
process_cos_sim_bootstrap <- function(country_name, local_glove, local_transform, B = 10) {

  # Read the pre-generated tokens object for this country.
  target_toks <- readRDS(paste0("data/analysis_toks/", country_name, "_target_toks_leader.rds"))

  # The document count does not change between iterations, so compute it once.
  n_docs <- ndoc(target_toks)

  # Preallocate one slot per bootstrap iteration.
  bootstrap_results <- vector("list", B)

  for (i in seq_len(B)) {
    # Draw document indices with replacement and subset the tokens object.
    boot_indices <- sample(seq_len(n_docs), replace = TRUE)
    boot_toks <- target_toks[boot_indices]

    # get_similarity_scores2() takes a file path, so the resample is written
    # to a temporary file. The `finally` clause removes that file whether the
    # scoring succeeds or fails; otherwise every iteration would leave a
    # serialized tokens object behind in the session temp directory.
    temp_file <- tempfile(fileext = ".rds")
    boot_res <- tryCatch(
      {
        saveRDS(boot_toks, temp_file)
        get_similarity_scores2(
          target_toks_file = temp_file,
          target           = "TARGETWORD",
          first_vec        = first_ar,
          second_vec       = second_ar,
          pre_trained      = local_glove,
          transform_matrix = local_transform,
          group_var        = "yearwk",  # Change this if a different grouping is desired.
          window           = 12L,
          norm             = "l2"
        )
      },
      finally = unlink(temp_file)
    )

    # Tag the rows with the iteration they came from.
    boot_res$bootstrap_iter <- i
    bootstrap_results[[i]] <- boot_res

    cat("Completed bootstrap iteration", i, "for country", country_name, "\n")
  }

  # Combine all bootstrap iterations into one data frame.
  boot_df <- bind_rows(bootstrap_results)

  # Summarise the bootstrap distribution within each group. `se` is the
  # standard deviation of `val` across bootstrap iterations.
  stats_summary <- boot_df %>%
    group_by(group) %>%
    summarise(
      mean     = mean(val, na.rm = TRUE),
      median   = median(val, na.rm = TRUE),
      se       = sd(val, na.rm = TRUE),
      lower_ci = quantile(val, probs = 0.025, na.rm = TRUE),
      upper_ci = quantile(val, probs = 0.975, na.rm = TRUE)
    )

  # Hand back both the full distribution and its summary.
  list(bootstrap_distribution = boot_df, stats_summary = stats_summary)
}

# Run the per-country bootstrap for one embedding version and save the
# outputs.
#
# Args:
#   version: Suffix identifying the embedding files to load, e.g. the
#            "<version>" in "combined_local_glove<version>.rds".
#   B:       Number of bootstrap iterations forwarded to
#            process_cos_sim_bootstrap() (default 1000).
#
# Side effects:
#   For every entry of the global `countries`, writes
#   "cos_sims_bootstrap_<version>.rds" and "stats_summary_<version>.rds"
#   under "data/output/robustness/cos_sims_bootstrap/<country>", creating
#   the directory if it does not exist yet.
#
# Returns NULL invisibly; the function is called for its side effects.
process_version_bootstrap <- function(version, B = 1000) {
  # Load the transformation matrix and the embeddings for this version.
  transform_path <- paste0("data/embedding_combined/combined_local_transform", version, ".rds")
  glove_path <- paste0("data/embedding_combined/combined_local_glove", version, ".rds")
  local_transform <- readRDS(transform_path)
  local_glove <- readRDS(glove_path)

  # First pass: compute the bootstrap for every country.
  per_country <- vector("list", length(countries))
  for (k in seq_along(countries)) {
    per_country[[k]] <- process_cos_sim_bootstrap(
      countries[[k]],
      local_glove,
      local_transform,
      B = B
    )
  }

  # Second pass: write each country's outputs into its own directory.
  for (k in seq_along(countries)) {
    country_dir <- paste0("data/output/robustness/cos_sims_bootstrap/", countries[k])
    if (!dir.exists(country_dir)) {
      dir.create(country_dir, recursive = TRUE)
    }

    country_res <- per_country[[k]]
    distribution_path <- paste0(country_dir, "/", "cos_sims_bootstrap_", version, ".rds")
    summary_path <- paste0(country_dir, "/", "stats_summary_", version, ".rds")

    saveRDS(country_res$bootstrap_distribution, file = distribution_path)
    saveRDS(country_res$stats_summary, file = summary_path)
  }

  invisible(NULL)
}

# Embedding version(s) to process. Each sample size is rendered without
# scientific notation (1.5e6 becomes "1500000") and suffixed with "30k" to
# form the version string used in the embedding file names.
sample_sizes <- vapply(
  c(1.5e6),
  function(size) format(size, scientific = FALSE),
  character(1)
)
versions <- paste0(sample_sizes, "30k")

# Run the bootstrap once per version; B = 10 keeps this example run short.
lapply(versions, function(v) process_version_bootstrap(v, B = 10))